{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "spark=SparkSession.builder.appName('log_reg').getOrCreate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the CSV using its first row as column names; inferSchema=True asks Spark to\n",
    "# detect column types instead of reading every column as a string.\n",
    "df=spark.read.csv('log_reg.csv',inferSchema=True,header=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Namespaced import: a wildcard import of pyspark.sql.functions shadows Python builtins\n",
    "# such as sum, min, max, abs and round for the rest of the notebook.\n",
    "from pyspark.sql import functions as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(15056, 6)\n"
     ]
    }
   ],
   "source": [
    "print((df.count(),len(df.columns)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- 城市: string (nullable = true)\n",
      " |-- 年龄: integer (nullable = true)\n",
      " |-- 是否第一次访问: integer (nullable = true)\n",
      " |-- 电商平台: string (nullable = true)\n",
      " |-- 搜索产品数: integer (nullable = true)\n",
      " |-- 是否购买: integer (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----+----+--------------+--------+----------+--------+\n",
      "|城市|年龄|是否第一次访问|电商平台|搜索产品数|是否购买|\n",
      "+----+----+--------------+--------+----------+--------+\n",
      "|广州|  38|             1|  TaoBao|        12|       0|\n",
      "|深圳|  29|             1|  TaoBao|        11|       0|\n",
      "|深圳|  37|             1|JingDong|         2|       0|\n",
      "|上海|  32|             1|  SuNing|         1|       0|\n",
      "|北京|  33|             1|JingDong|         3|       0|\n",
      "+----+----+--------------+--------+----------+--------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+-----+-----------------+------------------+--------+-----------------+-------------------+\n",
      "|summary| 城市|             年龄|    是否第一次访问|电商平台|       搜索产品数|           是否购买|\n",
      "+-------+-----+-----------------+------------------+--------+-----------------+-------------------+\n",
      "|  count|15056|            15056|             15056|   15056|            15056|              15056|\n",
      "|   mean| null|28.51268597236982|0.5001992561105207|    null|9.529357066950054|0.49787460148777896|\n",
      "| stddev| null|7.878922594747357|0.5000165657987499|    null|6.075813641256771| 0.5000120880138007|\n",
      "|    min| 上海|               14|                 0|JingDong|                1|                  0|\n",
      "|    max| 深圳|              111|                 1|  TaoBao|               29|                  1|\n",
      "+-------+-----+-----------------+------------------+--------+-----------------+-------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.describe().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----+-----+\n",
      "|城市|count|\n",
      "+----+-----+\n",
      "|深圳| 1947|\n",
      "|上海| 9184|\n",
      "|北京|  887|\n",
      "|广州| 3038|\n",
      "+----+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('城市').count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------+-----+\n",
      "|电商平台|count|\n",
      "+--------+-----+\n",
      "|JingDong| 4321|\n",
      "|  TaoBao| 7432|\n",
      "|  SuNing| 3303|\n",
      "+--------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('电商平台').count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------+-----+\n",
      "|是否购买|count|\n",
      "+--------+-----+\n",
      "|       1| 7496|\n",
      "|       0| 7560|\n",
      "+--------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('是否购买').count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----+------------------+-------------------+------------------+-------------------+\n",
      "|城市|         avg(年龄)|avg(是否第一次访问)|   avg(搜索产品数)|      avg(是否购买)|\n",
      "+----+------------------+-------------------+------------------+-------------------+\n",
      "|深圳| 30.26707755521315| 0.3261427837699024| 4.927067282999486|0.03697996918335902|\n",
      "|上海|28.408101045296167|  0.518074912891986| 9.976698606271777| 0.5431184668989547|\n",
      "|北京|27.704622322435174|  0.560315670800451|11.082299887260428|  0.644870349492672|\n",
      "|广州| 27.94042132982225| 0.5401579986833444|10.673140223831467| 0.6135615536537196|\n",
      "+----+------------------+-------------------+------------------+-------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('城市').mean().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------+------------------+-------------------+-----------------+-------------------+\n",
      "|电商平台|         avg(年龄)|avg(是否第一次访问)|  avg(搜索产品数)|      avg(是否购买)|\n",
      "+--------+------------------+-------------------+-----------------+-------------------+\n",
      "|JingDong|28.370284656329552| 0.5123813931960194|9.809766257810692| 0.5223327933348761|\n",
      "|  TaoBao|28.560145317545746|  0.503767491926803|9.515473627556512| 0.5006727664155005|\n",
      "|  SuNing|28.592188919164396| 0.4762337269149258|9.193763245534363|0.45958219800181654|\n",
      "+--------+------------------+-------------------+-----------------+-------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('电商平台').mean().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import StringIndexer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit a StringIndexer on the platform column and append the numeric index column\n",
    "# '电商平台索引'. In the df.show(10) output that follows, the most frequent platform\n",
    "# (TaoBao) receives index 0.0.\n",
    "commerce_indexer = StringIndexer(inputCol=\"电商平台\", outputCol=\"电商平台索引\").fit(df)\n",
    "df = commerce_indexer.transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----+----+--------------+--------+----------+--------+------------+\n",
      "|城市|年龄|是否第一次访问|电商平台|搜索产品数|是否购买|电商平台索引|\n",
      "+----+----+--------------+--------+----------+--------+------------+\n",
      "|广州|  38|             1|  TaoBao|        12|       0|         0.0|\n",
      "|深圳|  29|             1|  TaoBao|        11|       0|         0.0|\n",
      "|深圳|  37|             1|JingDong|         2|       0|         1.0|\n",
      "|上海|  32|             1|  SuNing|         1|       0|         2.0|\n",
      "|北京|  33|             1|JingDong|         3|       0|         1.0|\n",
      "|深圳|  33|             1|JingDong|        18|       0|         1.0|\n",
      "|深圳|  29|             0|JingDong|         9|       0|         1.0|\n",
      "|上海|  24|             1|JingDong|         5|       1|         1.0|\n",
      "|上海|  29|             0|  TaoBao|        10|       0|         0.0|\n",
      "|上海|  28|             0|  SuNing|         3|       1|         2.0|\n",
      "+----+----+--------------+--------+----------+--------+------------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.show(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import OneHotEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot encode the platform index. The last index is dropped by default, so the three\n",
    "# platforms become a 2-element sparse vector and index 2.0 (SuNing) is the all-zero vector.\n",
    "commerce_vector = OneHotEncoder(inputCol=\"电商平台索引\", outputCol=\"电商平台索引向量\")\n",
    "df = commerce_vector.transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----+----+--------------+--------+----------+--------+------------+----------------+\n",
      "|城市|年龄|是否第一次访问|电商平台|搜索产品数|是否购买|电商平台索引|电商平台索引向量|\n",
      "+----+----+--------------+--------+----------+--------+------------+----------------+\n",
      "|广州|38  |1             |TaoBao  |12        |0       |0.0         |(2,[0],[1.0])   |\n",
      "|深圳|29  |1             |TaoBao  |11        |0       |0.0         |(2,[0],[1.0])   |\n",
      "|深圳|37  |1             |JingDong|2         |0       |1.0         |(2,[1],[1.0])   |\n",
      "+----+----+--------------+--------+----------+--------+------------+----------------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.show(3,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------+-----+\n",
      "|电商平台索引向量|count|\n",
      "+----------------+-----+\n",
      "|(2,[0],[1.0])   |7432 |\n",
      "|(2,[1],[1.0])   |4321 |\n",
      "|(2,[],[])       |3303 |\n",
      "+----------------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('电商平台索引向量').count().orderBy('count',ascending=False).show(5,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------+-----+\n",
      "|电商平台索引向量|count|\n",
      "+----------------+-----+\n",
      "|(2,[0],[1.0])   |7432 |\n",
      "|(2,[1],[1.0])   |4321 |\n",
      "|(2,[],[])       |3303 |\n",
      "+----------------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('电商平台索引向量').count().orderBy('count',ascending=False).show(5,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----+----+--------------+--------+----------+--------+------------+----------------+--------+-------------+\n",
      "|城市|年龄|是否第一次访问|电商平台|搜索产品数|是否购买|电商平台索引|电商平台索引向量|城市索引| 城市索引向量|\n",
      "+----+----+--------------+--------+----------+--------+------------+----------------+--------+-------------+\n",
      "|广州|  38|             1|  TaoBao|        12|       0|         0.0|   (2,[0],[1.0])|     1.0|(3,[1],[1.0])|\n",
      "|深圳|  29|             1|  TaoBao|        11|       0|         0.0|   (2,[0],[1.0])|     2.0|(3,[2],[1.0])|\n",
      "|深圳|  37|             1|JingDong|         2|       0|         1.0|   (2,[1],[1.0])|     2.0|(3,[2],[1.0])|\n",
      "|上海|  32|             1|  SuNing|         1|       0|         2.0|       (2,[],[])|     0.0|(3,[0],[1.0])|\n",
      "|北京|  33|             1|JingDong|         3|       0|         1.0|   (2,[1],[1.0])|     3.0|    (3,[],[])|\n",
      "|深圳|  33|             1|JingDong|        18|       0|         1.0|   (2,[1],[1.0])|     2.0|(3,[2],[1.0])|\n",
      "|深圳|  29|             0|JingDong|         9|       0|         1.0|   (2,[1],[1.0])|     2.0|(3,[2],[1.0])|\n",
      "|上海|  24|             1|JingDong|         5|       1|         1.0|   (2,[1],[1.0])|     0.0|(3,[0],[1.0])|\n",
      "|上海|  29|             0|  TaoBao|        10|       0|         0.0|   (2,[0],[1.0])|     0.0|(3,[0],[1.0])|\n",
      "|上海|  28|             0|  SuNing|         3|       1|         2.0|       (2,[],[])|     0.0|(3,[0],[1.0])|\n",
      "|北京|  24|             1|JingDong|         4|       0|         1.0|   (2,[1],[1.0])|     3.0|    (3,[],[])|\n",
      "|上海|  26|             0|  TaoBao|        17|       1|         0.0|   (2,[0],[1.0])|     0.0|(3,[0],[1.0])|\n",
      "|上海|  30|             1|  TaoBao|        21|       1|         0.0|   (2,[0],[1.0])|     0.0|(3,[0],[1.0])|\n",
      "|上海|  32|             1|  SuNing|         4|       0|         2.0|       (2,[],[])|     0.0|(3,[0],[1.0])|\n",
      "|广州|  24|             1|  TaoBao|         2|       0|         0.0|   (2,[0],[1.0])|     1.0|(3,[1],[1.0])|\n",
      "|上海|  32|             1|  SuNing|        23|       1|         2.0|       (2,[],[])|     0.0|(3,[0],[1.0])|\n",
      "|上海|  16|             0|  TaoBao|         8|       1|         0.0|   (2,[0],[1.0])|     0.0|(3,[0],[1.0])|\n",
      "|上海|  14|             0|  TaoBao|         4|       0|         0.0|   (2,[0],[1.0])|     0.0|(3,[0],[1.0])|\n",
      "|北京|  43|             0|  TaoBao|         8|       1|         0.0|   (2,[0],[1.0])|     3.0|    (3,[],[])|\n",
      "|上海|  21|             0|JingDong|         5|       0|         1.0|   (2,[1],[1.0])|     0.0|(3,[0],[1.0])|\n",
      "+----+----+--------------+--------+----------+--------+------------+----------------+--------+-------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Index and one-hot encode the city column the same way as the platform column.\n",
    "# The VectorAssembler cell further down consumes '城市索引向量', so the column has to be\n",
    "# created by the notebook itself for a fresh-kernel Run All to succeed.\n",
    "city_indexer = StringIndexer(inputCol=\"城市\", outputCol=\"城市索引\").fit(df)\n",
    "df = city_indexer.transform(df)\n",
    "city_vector = OneHotEncoder(inputCol=\"城市索引\", outputCol=\"城市索引向量\")\n",
    "df = city_vector.transform(df)\n",
    "df.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import VectorAssembler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate the two one-hot vectors and the three numeric columns into a single\n",
    "# feature vector column ('特征值向量'), which the classifier below uses as featuresCol.\n",
    "df_assembler = VectorAssembler(inputCols=['电商平台索引向量','城市索引向量','年龄', '是否第一次访问','搜索产品数'], outputCol=\"特征值向量\")\n",
    "df = df_assembler.transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- 城市: string (nullable = true)\n",
      " |-- 年龄: integer (nullable = true)\n",
      " |-- 是否第一次访问: integer (nullable = true)\n",
      " |-- 电商平台: string (nullable = true)\n",
      " |-- 搜索产品数: integer (nullable = true)\n",
      " |-- 是否购买: integer (nullable = true)\n",
      " |-- 电商平台索引: double (nullable = false)\n",
      " |-- 电商平台索引向量: vector (nullable = true)\n",
      " |-- 城市索引: double (nullable = false)\n",
      " |-- 城市索引向量: vector (nullable = true)\n",
      " |-- 特征值向量: vector (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------------------------------+--------+\n",
      "|特征值向量                         |是否购买|\n",
      "+-----------------------------------+--------+\n",
      "|[1.0,0.0,0.0,1.0,0.0,38.0,1.0,12.0]|0       |\n",
      "|[1.0,0.0,0.0,0.0,1.0,29.0,1.0,11.0]|0       |\n",
      "|[0.0,1.0,0.0,0.0,1.0,37.0,1.0,2.0] |0       |\n",
      "|(8,[2,5,6,7],[1.0,32.0,1.0,1.0])   |0       |\n",
      "|(8,[1,5,6,7],[1.0,33.0,1.0,3.0])   |0       |\n",
      "|[0.0,1.0,0.0,0.0,1.0,33.0,1.0,18.0]|0       |\n",
      "|(8,[1,4,5,7],[1.0,1.0,29.0,9.0])   |0       |\n",
      "|[0.0,1.0,1.0,0.0,0.0,24.0,1.0,5.0] |1       |\n",
      "|(8,[0,2,5,7],[1.0,1.0,29.0,10.0])  |0       |\n",
      "|(8,[2,5,7],[1.0,28.0,3.0])         |1       |\n",
      "+-----------------------------------+--------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.select(['特征值向量','是否购买']).show(10,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_df=df.select(['特征值向量','是否购买'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import LogisticRegression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 70/30 train/test split. The fixed seed makes the split, and therefore every count\n",
    "# and metric computed from it below, repeatable from one run to the next.\n",
    "training_df,test_df=model_df.randomSplit([0.7,0.3],seed=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10490"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_df.count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------+-----+\n",
      "|是否购买|count|\n",
      "+--------+-----+\n",
      "|       1| 5239|\n",
      "|       0| 5251|\n",
      "+--------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "training_df.groupBy('是否购买').count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4566"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_df.count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------+-----+\n",
      "|是否购买|count|\n",
      "+--------+-----+\n",
      "|       1| 2257|\n",
      "|       0| 2309|\n",
      "+--------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "test_df.groupBy('是否购买').count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit a logistic regression on the training split. fit() returns the trained model,\n",
    "# so log_reg is the fitted model (a Transformer), not the estimator itself.\n",
    "log_reg=LogisticRegression(featuresCol='特征值向量',labelCol='是否购买').fit(training_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml import Estimator\n",
    "from pyspark.ml import Transformer\n",
    "isinstance(log_reg,Transformer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate() scores the given DataFrame and returns a summary object; its .predictions\n",
    "# attribute is a DataFrame that carries the label plus the model's prediction and\n",
    "# probability columns (selected in the next cell).\n",
    "train_results=log_reg.evaluate(training_df).predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------+----------+----------------------------------------+\n",
      "|是否购买|prediction|probability                             |\n",
      "+--------+----------+----------------------------------------+\n",
      "|1       |1.0       |[0.2967915035494824,0.7032084964505176] |\n",
      "|1       |1.0       |[0.1681037900675889,0.8318962099324111] |\n",
      "|1       |1.0       |[0.1681037900675889,0.8318962099324111] |\n",
      "|1       |1.0       |[0.08821488317232305,0.9117851168276769]|\n",
      "|1       |1.0       |[0.08821488317232305,0.9117851168276769]|\n",
      "|1       |1.0       |[0.04427156521457076,0.9557284347854291]|\n",
      "|1       |1.0       |[0.04427156521457076,0.9557284347854291]|\n",
      "|1       |1.0       |[0.04427156521457076,0.9557284347854291]|\n",
      "|1       |1.0       |[0.02169724782921141,0.9783027521707887]|\n",
      "|1       |1.0       |[0.02169724782921141,0.9783027521707887]|\n",
      "+--------+----------+----------------------------------------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "train_results.filter(train_results['是否购买']==1).filter(train_results['prediction']==1).select(['是否购买','prediction','probability']).show(10,False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `probability` vector holds the class-0 probability at index 0 and the class-1 probability at index 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Despite the name, this counts only the true positives on the training set\n",
    "# (label 1 AND prediction 1), not every correct prediction.\n",
    "correct_preds=train_results.filter(train_results['是否购买']==1).filter(train_results['prediction']==1).count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5239"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_df.filter(training_df['是否购买']==1).count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9366291276961252"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# True positives / all actual positives = recall on the training set.\n",
    "# float() on the numerator keeps the division from truncating on this Python 2 kernel.\n",
    "float(correct_preds)/(training_df.filter(training_df['是否购买']==1).count())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "预测准确率：0.939084842707\n"
     ]
    }
   ],
   "source": [
    "print('{}{}'.format('预测准确率：',log_reg.evaluate(training_df).accuracy) )  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Test Set results\n",
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
    "isinstance(log_reg.evaluate,BinaryClassificationEvaluator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "results=log_reg.evaluate(test_df).predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------+----------+\n",
      "|是否购买|prediction|\n",
      "+--------+----------+\n",
      "|0       |0.0       |\n",
      "|0       |0.0       |\n",
      "|0       |0.0       |\n",
      "|0       |0.0       |\n",
      "|0       |0.0       |\n",
      "|1       |0.0       |\n",
      "|0       |1.0       |\n",
      "|1       |1.0       |\n",
      "|1       |1.0       |\n",
      "|1       |1.0       |\n",
      "+--------+----------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "results.select(['是否购买','prediction']).show(10,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- 特征值向量: vector (nullable = true)\n",
      " |-- 是否购买: integer (nullable = true)\n",
      " |-- rawPrediction: vector (nullable = true)\n",
      " |-- probability: vector (nullable = true)\n",
      " |-- prediction: double (nullable = false)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "results.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataFrame[特征值向量: vector, 是否购买: int, rawPrediction: vector, probability: vector, prediction: double]"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
    "results[results['是否购买']==1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "#confusion matrix\n",
    "true_postives = results[(results['是否购买'] == 1) & (results.prediction == 1)].count()\n",
    "true_negatives = results[(results['是否购买'] == 0) & (results.prediction == 0)].count()\n",
    "false_positives = results[(results['是否购买'] == 0) & (results.prediction == 1)].count()\n",
    "false_negatives = results[(results['是否购买'] == 1) & (results.prediction == 0)].count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2121\n",
      "2158\n",
      "151\n",
      "136\n",
      "4566\n",
      "4566\n"
     ]
    }
   ],
   "source": [
    "print (true_postives)\n",
    "print (true_negatives)\n",
    "print (false_positives)\n",
    "print (false_negatives)\n",
    "print(true_postives+true_negatives+false_positives+false_negatives)\n",
    "print (results.count())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.93974302171\n"
     ]
    }
   ],
   "source": [
    "# Recall = TP / (TP + FN). float() on the numerator avoids Python 2 integer division.\n",
    "recall = float(true_postives)/(true_postives + false_negatives)\n",
    "print(recall)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.933538732394\n"
     ]
    }
   ],
   "source": [
    "# Precision = TP / (TP + FP). float() on the numerator avoids Python 2 integer division.\n",
    "precision = float(true_postives) / (true_postives + false_positives)\n",
    "print(precision)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n"
     ]
    }
   ],
   "source": [
    "# Accuracy = (TP + TN) / total test rows.\n",
    "# Cast to float BEFORE dividing: on a Python 2 kernel int/int is floor division, so\n",
    "# wrapping the finished quotient in float() would always give 0.0 here.\n",
    "accuracy=float(true_postives+true_negatives)/results.count()\n",
    "print(accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4566"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results.count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4279"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "true_postives+true_negatives"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
